package ffte.test

import spinal.core._
import spinal.core.sim._

import breeze.math.Complex
import scala.util.Random._

import ffte.evaluate.cases.{Sim,FFTCase}
import ffte.evaluate.vectors.Cplx

import ffte.core.flat.{FIFFT,FFFT}
import ffte.property.{getShiftMethod,getFFTIW,getFFTOW,FFTIOConfig}

import ffte.algorithm.FFTGen
import ffte.core.streamed.{Streamed,streamedComponents}

/**
  * Simulation harness that drives a generated `Streamed` FFT component with
  * randomised valid/ready handshaking and compares its output samples against
  * the expected values held in the test vectors.
  *
  * Names such as `N`, `K`, `C`, `S`, `CC`, `idx`, `failcnt`, `ok`, `debug`,
  * `rgen`, `inv`, `test_vectors`, `done` and `next` are not defined in this
  * object; presumably they are inherited from [[Sim]] -- confirm there.
  */
object StreamedStudy extends Sim {
    // Data widths, initialised from the property getters; `main` may override
    // them from the command line before a single test is run.
    var dW   = getFFTIW()
    var oW   = getFFTOW()

    // Random stimulus for the input `valid` and the output `ready` signals.
    def valid() = nextBoolean
    def ready() = nextBoolean

    /**
      * Compiles the streamed FFT for the current `N` and simulates it.
      *
      * @return a pair `(flag, failcnt)` where `flag` is
      *         `failcnt != (C-delay)/N` and `failcnt` counts the output frames
      *         in which at least one sample was off by more than 32 (norm1).
      */
    def sim : (Boolean,Int) = {
        C = 100*N
        // Extra shift of (dW-oW) is only added when the input is wider than the output.
        S   = getShiftMethod(N) + (if(dW>oW) (dW-oW) else 0)
        val fp = if(rgen) (FFTGen.rgen[Streamed](N)) else (FFTGen.gen[Streamed](N))
        // Index tables used to pick the input sample / expected output sample.
        val tab     = fp.load_tab
        val out_tab = fp.store_tab
        // Filled in during elaboration (inside SimConfig.compile) from the DUT.
        var delay = 0
        val tv = test_vectors
        println(s"simulation $N gen test vectors ${getFFTIW()}")
        var ridx = 0      // index of the output sample inside the current output frame
        var fok  = true   // false once a mismatch was seen in the current output frame

        SimConfig.compile {
            val fft = fp.gen(new Streamed(S)).asInstanceOf[streamedComponents]
            delay = fft.delay
            fft
        }.doSim{ dut =>
            // `done` receives two blocks; presumably the first is a one-off setup
            // and the second is the per-cycle body -- TODO confirm in Sim.
            done({
                idx = 0
                ridx = 0
                failcnt = 0
                dut.clockDomain.forkStimulus(period = 10)
                dut.io.q.ready        #= true
                // Pulse `last` for one clock edge before the real stimulus starts.
                dut.io.d.payload.last #= true
                dut.clockDomain.waitRisingEdge()
                dut.io.d.payload.last #= false
            },{
                val frame  = idx/N
                val iidx   = idx%N
                val tframe = frame%CC
                // Output frame that corresponds to the current input index once
                // the component delay is subtracted.
                val rframe = (idx-delay)/N
                val d : Array[Cplx] = tv(tframe).d

                dut.io.d.payload.fragment.d(0).re.d.raw #= d(tab(iidx)).re
                dut.io.d.payload.fragment.d(0).im.d.raw #= d(tab(iidx)).im
                dut.io.d.valid            #= valid()
                dut.io.q.ready            #= ready()
                // `last` marks the final sample of every input frame.
                dut.io.d.payload.last     #= (iidx==N-1)
                dut.clockDomain.waitRisingEdge()
                // NOTE(review): checking is gated on the *input* handshake
                // (d.ready && d.valid), not on q.valid/q.ready -- confirm intended.
                if(dut.io.d.ready.toBoolean && dut.io.d.valid.toBoolean) {
                    next
                    // Output samples are only compared once rframe exceeds 1.
                    if(rframe>1) {
                        val q = tv(rframe%CC).q(out_tab(ridx%N))
                        val r = Cplx(dut.io.q.payload.fragment.d(0).re.d.raw.toInt,dut.io.q.payload.fragment.d(0).im.d.raw.toInt)
                        // Tolerate a norm1 difference of up to 32.
                        if((r-q).norm1>32) {
                            if(debug>2) {
                                println(s"$rframe($ridx) $r != $q")
                            }
                            fok = false
                            ok  = false
                        }
                    }
                    if(rframe>=0) {
                        if(dut.io.q.payload.last.toBoolean) {
                            // End of an output frame: account for it and reset the
                            // per-frame state.
                            ridx = 0
                            if(!fok) failcnt += 1
                            fok = true
                        } else {
                            ridx += 1
                        }
                    }
                }
            })
        }
        // NOTE(review): the flag is true unless failcnt equals (C-delay)/N, i.e.
        // it is not a "zero failures" check -- confirm this is the intended criterion.
        (failcnt!=(C-delay)/N,failcnt)
    }

    /**
      * Runs one simulation for size `n` with parameter `k` at debug level 3,
      * using the current `dW`/`oW` as the I/O configuration, and prints the
      * verdict followed by the shift value.
      */
    def singleTest(n:Int,k:Int):Unit = {
        debug = 3
        N = n
        K = k
        FFTGen.Winograd
        val (_ook,failcnt) = FFTIOConfig(dW,oW) on sim
        if(_ook) println(s"$n($inv): PASS($failcnt)") else println(s"$n: FAIL")
        println(s"shift:$S")
    }

    /**
      * Runs `sim` for every key of `cases` in ascending order, appends each
      * verdict to `tmp/streamed.log`, then prints a summary of all verdicts.
      */
    def caseTest(cases:Map[Int,Int]): Unit = {
        K = 0
        debug = 1
        // Only the keys (FFT sizes) are used; the map values are ignored.
        val nList = cases.keys.toList.sorted
        // `map` on a List is strict, so every simulation has finished before the
        // summary loop below starts printing.
        nList.map { n =>
            N = n
            println(s"start $N test")
            val (_ook,failcnt) = sim
            val msg = if(_ook) s"$n($inv): PASS($failcnt)" else s"$n: FAIL"
            Sim.log2file("tmp/streamed.log",msg)
            (n,_ook,failcnt)
        }.foreach{case (n,oook,failcnt) =>
            if(oook) println(s"$n($inv): PASS($failcnt)") else println(s"$n: FAIL")
        }
    }

    /**
      * Entry point. Arguments, all optional:
      *   args(0)  1 selects `rgen`, anything else (or absent) disables it
      *   args(1)  n -- when present a single test is run instead of the full case list
      *   args(2)  k -- defaults to 0 when omitted
      *   args(3)  dW override
      *   args(4)  oW override
      */
    def main(args: Array[String]): Unit = {
        ffte.default.core.GlobeSetting.evaluateNative = false
        rgen = args.length>0 && args(0).toInt==1
        if(args.length>1) {
            val n = args(1).toInt
            // k used to be read unconditionally, which threw
            // ArrayIndexOutOfBoundsException when only two arguments were given.
            // Fall back to 0, the same value caseTest assigns to K.
            val k = if(args.length>2) args(2).toInt else 0
            if(args.length>3) {
                dW = args(3).toInt
            }
            if(args.length>4) {
                oW = args(4).toInt
            }
            singleTest(n,k)
        } else caseTest(FFTCase.full)
    }
}